{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 手写数字识别(One vs All多分类)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* 需要识别的数字是32*32像素的黑白图像"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* trainData大约2000个样本，testData大约是900个测试数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 将图像文本数据转换为向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import operator#运算符模块\n",
    "from os import listdir\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def img2vector(filename):\n",
    "    returnVet = np.zeros((1,1024))\n",
    "    fr = open(filename)\n",
    "    for i in range(32):\n",
    "        lineStr = fr.readline()\n",
    "        for j in range(32):\n",
    "            returnVet[0,32*i+j] = int(lineStr[j])\n",
    "    return returnVet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "testVector = img2vector('testDigits/0_13.txt')\n",
    "print(testVector[0,0:32])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "#计算距离\n",
    "def classify0(inX, dataSet, labels, k):\n",
    "    dataSetSize = dataSet.shape[0]\n",
    "    #距离度量 度量公式为欧氏距离\n",
    "    diffMat = np.tile(inX, (dataSetSize,1)) - dataSet\n",
    "    sqDiffMat = diffMat**2\n",
    "    sqDistances = sqDiffMat.sum(axis=1)\n",
    "    distances = sqDistances**0.5\n",
    "    \n",
    "    #将距离排序：从小到大\n",
    "    sortedDistIndicies = distances.argsort()\n",
    "    #选取前K个最短距离， 选取这K个中最多的分类类别\n",
    "    classCount={}\n",
    "    for i in range(k):\n",
    "        voteIlabel = labels[sortedDistIndicies[i]]\n",
    "        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 \n",
    "    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)\n",
    "    return sortedClassCount[0][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def handwritingClass():\n",
    "    #导入训练数据\n",
    "    hwlabels = []\n",
    "    trainingFileList = listdir('trainingDigits')\n",
    "    m = len(trainingFileList)\n",
    "    trainMat = np.zeros((m,1024))\n",
    "    #hwlabels存储0-9对应的index位置，trainMat存储每个位置对应的图片向量\n",
    "    for i in range(m):\n",
    "        fileName = trainingFileList[i]\n",
    "        fileStr = fileName.split('.')[0]#获取每一个样本,去掉.txt\n",
    "        classNum = int(fileStr.split('_')[0])#去掉_,获取每个样本前的数字\n",
    "        hwlabels.append(classNum)\n",
    "        #32*32矩阵转换为1*1024\n",
    "        trainMat[i,:] = img2vector('trainingDigits/%s' % fileName)\n",
    "        \n",
    "    #导入测试数据\n",
    "    testFileList = listdir('testDigits')  # iterate through the test set\n",
    "    errorCount = 0.0\n",
    "    mTest = len(testFileList)\n",
    "    for i in range(mTest):\n",
    "        fileNameStr = testFileList[i]\n",
    "        fileStr = fileNameStr.split('.')[0]  # take off .txt\n",
    "        classNumStr = int(fileStr.split('_')[0])\n",
    "        vectorUnderTest = img2vector('testDigits/%s' % fileName)\n",
    "        classifierResult = classify0(vectorUnderTest, trainMat, hwlabels, 3)\n",
    "        print (\"the classifier came back with: %d, the real answer is: %d\" % (classifierResult, classNum))\n",
    "        if (classifierResult != classNum): errorCount += 1.0\n",
    "    print (\"\\nthe total number of errors is: %d\" % errorCount)\n",
    "    print (\"\\nthe total error rate is: %f\" % (errorCount / float(mTest)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "handwritingClass()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
